In [ ]:
epochs = 15  # number of passes over the training set
Recap: The previous tutorial looked at building a basic SplitNN, where a neural network was split into two segments on two separate hosts. However, there is a lot more we can do with this technique: a network can be split any number of times without affecting the accuracy of the model.
Description: Here we define a class which can drive a SplitNN with any number of segments. All it needs is a list of distributed models and their optimizers.
In this tutorial, we demonstrate the SplitNN class with a three-segment distribution [1]. We use exactly the same model as in the previous tutorial, only this time it is split over three hosts instead of two. We see the same loss being reported, as there is no reduction in accuracy when training this way, and although we use only three segments here, the approach works for any number of models.
Author:
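Before defining the class, here is a minimal sketch of the trick that makes a split network trainable. It uses plain PyTorch on a single host (no PySyft; `seg1`, `seg2` and the shapes are illustrative only): the activation is detached at each split, and during the backward pass the gradient accumulated at the boundary is relayed into the upstream segment.
In [ ]:
import torch

# Two local segments standing in for two remote hosts
seg1 = torch.nn.Linear(4, 3)
seg2 = torch.nn.Linear(3, 2)

x = torch.randn(1, 4)
out1 = seg1(x)

# Cut the autograd graph at the split, as SplitNN.forward does below
in2 = out1.detach().requires_grad_()
out2 = seg2(in2)

out2.sum().backward()    # backprop through segment 2 only
out1.backward(in2.grad)  # relay the boundary gradient into segment 1
print(seg1.weight.grad)  # segment 1 now has gradients too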
In [ ]:
import torch
from torchvision import datasets, transforms
from torch import nn, optim
import syft as sy
hook = sy.TorchHook(torch)
In [ ]:
class SplitNN(torch.nn.Module):
    def __init__(self, models, optimizers):
        super().__init__()
        self.models = models
        self.optimizers = optimizers
        # Cache each segment's input and output so backward() can
        # relay gradients across the splits
        self.outputs = [None] * len(self.models)
        self.inputs = [None] * len(self.models)

    def forward(self, x):
        self.inputs[0] = x
        self.outputs[0] = self.models[0](self.inputs[0])
        for i in range(1, len(self.models)):
            # Detach at the split so each segment owns its own autograd graph
            self.inputs[i] = self.outputs[i - 1].detach().requires_grad_()
            # If the next segment lives on a different worker, move the
            # activation to it before continuing the forward pass
            if self.outputs[i - 1].location != self.models[i].location:
                self.inputs[i] = self.inputs[i].move(self.models[i].location).requires_grad_()
            self.outputs[i] = self.models[i](self.inputs[i])
        return self.outputs[-1]

    def backward(self):
        # Walk the splits from back to front, handing each boundary
        # gradient to the segment that produced the corresponding output
        for i in range(len(self.models) - 2, -1, -1):
            grad_in = self.inputs[i + 1].grad.copy()
            if self.outputs[i].location != self.inputs[i + 1].location:
                grad_in = grad_in.move(self.outputs[i].location)
            self.outputs[i].backward(grad_in)

    def zero_grads(self):
        for opt in self.optimizers:
            opt.zero_grad()

    def step(self):
        for opt in self.optimizers:
            opt.step()

    def train(self):
        for model in self.models:
            model.train()

    def eval(self):
        for model in self.models:
            model.eval()

    @property
    def location(self):
        return self.models[0].location if self.models else None
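Note the division of labour here: because forward() detaches each segment's input, a later loss.backward() only propagates gradients through the final segment. SplitNN.backward() then walks the remaining splits from back to front, moving each boundary gradient to the upstream host where necessary and continuing backpropagation there.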
In [ ]:
# Data preprocessing
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
trainset = datasets.MNIST('mnist', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
In [ ]:
torch.manual_seed(0)

# Define our model segments
input_size = 784
hidden_sizes = [128, 640]
output_size = 10

models = [
    nn.Sequential(
        nn.Linear(input_size, hidden_sizes[0]),
        nn.ReLU(),
    ),
    nn.Sequential(
        nn.Linear(hidden_sizes[0], hidden_sizes[1]),
        nn.ReLU(),
    ),
    nn.Sequential(
        nn.Linear(hidden_sizes[1], output_size),
        nn.LogSoftmax(dim=1)
    )
]

# Create an optimizer for each segment
optimizers = [
    optim.SGD(model.parameters(), lr=0.03)
    for model in models
]

# Create some workers
alice = sy.VirtualWorker(hook, id="alice")
bob = sy.VirtualWorker(hook, id="bob")
claire = sy.VirtualWorker(hook, id="claire")

# Send each model segment to its host
model_locations = [alice, bob, claire]
for model, location in zip(models, model_locations):
    model.send(location)

# Instantiate a SplitNN with our distributed segments and their respective optimizers
splitNN = SplitNN(models, optimizers)
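As an optional sanity check, we can confirm that each segment landed on its intended host. This sketch assumes the PySyft 0.2.x behaviour used throughout this tutorial, where sending a model records its worker in .location:
In [ ]:
# Each segment should report its assigned worker as its location
for model, worker in zip(models, model_locations):
    print(worker.id, model.location == worker)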
In [ ]:
def train(x, target, splitNN):
    # 1) Zero our grads
    splitNN.zero_grads()
    # 2) Make a prediction
    pred = splitNN.forward(x)
    # 3) Figure out how much we missed by
    criterion = nn.NLLLoss()
    loss = criterion(pred, target)
    # 4) Backprop the loss on the end segment
    loss.backward()
    # 5) Feed gradients backward through the network
    splitNN.backward()
    # 6) Change the weights
    splitNN.step()
    return loss
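Before committing to the full run, a single-batch smoke test can surface shape or location mistakes early. This sketch simply reuses the objects defined above on one batch:
In [ ]:
# Optional smoke test: push one batch through the full pipeline
images, labels = next(iter(trainloader))
images = images.send(models[0].location).view(images.shape[0], -1)
labels = labels.send(models[-1].location)
print("Single-batch loss:", train(images, labels, splitNN).get())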
In [ ]:
for i in range(epochs):
    running_loss = 0
    splitNN.train()
    for images, labels in trainloader:
        # Send images to the first segment's host and labels to the last
        images = images.send(models[0].location)
        images = images.view(images.shape[0], -1)
        labels = labels.send(models[-1].location)
        loss = train(images, labels, splitNN)
        running_loss += loss.get()
    print("Epoch {} - Training loss: {}".format(i, running_loss / len(trainloader)))
In [ ]:
def test(model, dataloader, dataset_name):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in dataloader:
            # Data goes to the first segment's host; labels stay local
            data = data.view(data.shape[0], -1).send(model.location)
            output = model(data).get()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    print("{}: Accuracy {}/{} ({:.0f}%)".format(dataset_name,
                                                correct,
                                                len(dataloader.dataset),
                                                100. * correct / len(dataloader.dataset)))
In [ ]:
testset = datasets.MNIST('mnist', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
test(splitNN, testloader, "Test set")
test(splitNN, trainloader, "Train set")